# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Differentiable Wildfire Simulator written in JAX!

Wildfire Simulator written in JAX based on
https://github.com/IhmeGroup/Wildfire-TPU.
"""
import functools as ft
from typing import Tuple, TypeVar, Type, Union

from flax import struct
import jax
from jax import jit
from jax import lax
import jax.numpy as jnp
import numpy as np

from wildfire_perc_sim import utils

# Type variables bound to `BurnState` and `SimulatorProperties` (declared
# below). NOTE(review): nothing in this chunk annotates with them; presumably
# intended for the return types of the `create` classmethods — confirm.
BST = TypeVar('BST', bound='BurnState')
SPT = TypeVar('SPT', bound='SimulatorProperties')


@struct.dataclass
class BurnState:
  """Per-cell state of the fire simulation.

  Attributes:
    lit: Cells that are currently on fire.
    fire: Per-cell counter that `burn_step` increments wherever `lit` is set.
    heat: Heat accumulated by each cell.
    burnt: Cells whose `fire` counter has reached the burn duration.
  """
  lit: jnp.ndarray
  fire: jnp.ndarray
  heat: jnp.ndarray
  burnt: jnp.ndarray

  @classmethod
  def create(cls, lit, fire,
             heat):
    """Build an initial state in which no cell is burnt yet.

    Args:
      lit: Cells that are initially on fire.
      fire: Initial per-cell fire counters.
      heat: Initial per-cell heat.

    Returns:
      A new state whose `burnt` field is `jnp.zeros_like(lit)`.
    """
    nothing_burnt = jnp.zeros_like(lit)
    return cls(lit=lit, fire=fire, heat=heat, burnt=nothing_burnt)


@struct.dataclass
class FieldProperties:
  """Per-field inputs to the simulator.

  The vmapped functions in this module map over these along axis 0, one entry
  per field in the batch.
  """
  # Multiplied by `SimulatorParameters.moisture_alpha` and added to the
  # density-scaled ignition heat in `get_ignition_heat`.
  moisture: jnp.ndarray
  # Its first-order gradient (`utils.gradient_o1(terrain, 1)`) supplies the
  # slope field consumed by `get_slope_factor`.
  terrain: jnp.ndarray
  # Vector field consumed by `get_wind_factor`; the wind magnitude is the norm
  # over its channel axis (see `_wind_factor_kernel`).
  wind: jnp.ndarray
  # Scales `nominal_ignition_heat`; cells where it equals 0 get an infinite
  # (or `utils.INF`) ignition heat in `get_ignition_heat`.
  density: jnp.ndarray


@struct.dataclass
class SimulatorParameters:
  """Sensitivity coefficients of the fire-spread model.

  Mapped over axis 0 (one entry per field) by
  `parameterized_generate_dynamic_kernel`.
  """
  # Multiplies the slope projection inside the exponent of the slope factor
  # (`_slope_factor_kernel`).
  slope_alpha: jnp.ndarray
  # Scales the wind term in the wind factor `1 + alpha * frac`
  # (`_wind_factor_kernel`).
  wind_alpha: jnp.ndarray
  # Scales `FieldProperties.moisture` before it is added to the ignition heat
  # (`get_ignition_heat`).
  moisture_alpha: jnp.ndarray


@struct.dataclass
class SimulatorProperties:
  """Configuration shared by every field of a simulation batch.

  Build instances with `SimulatorProperties.create`, which also derives
  `base_kernel`. The vmapped functions in this module receive these properties
  with `in_axes=None`, so they must not be batched (`vmap`ped over).

  Attributes:
    neighborhood_size: Neighborhood size handed to `get_kernel`.
    base_kernel: Kernel produced by `get_kernel(neighborhood_size)`.
    boundary_condition: Boundary condition used when padding fields.
    nominal_ignition_heat: Ignition heat per unit density (it is multiplied by
      the field density in `get_ignition_heat`).
    burn_duration: Number of steps a cell can stay lit before it is burnt.
    sigmoid_coefficient: Steepness of the sigmoids used by the differentiable
      approximation of the simulation.
  """
  neighborhood_size: int
  base_kernel: jnp.ndarray
  boundary_condition: utils.BoundaryCondition
  nominal_ignition_heat: jnp.ndarray
  burn_duration: jnp.ndarray
  sigmoid_coefficient: float

  @classmethod
  def create(cls, neighborhood_size,
             boundary_condition,
             nominal_ignition_heat, burn_duration,
             sigmoid_coefficient=15.0):
    """Construct the properties, deriving `base_kernel` from the neighborhood.

    Args:
      neighborhood_size: Neighborhood size; also used to build `base_kernel`.
      boundary_condition: Boundary condition used when padding fields.
      nominal_ignition_heat: Ignition heat per unit density.
      burn_duration: Number of steps a cell can stay lit before it is burnt.
      sigmoid_coefficient: Sigmoid steepness for the differentiable
        approximation.

    Returns:
      A new `SimulatorProperties`.
    """
    return cls(
        neighborhood_size=neighborhood_size,
        base_kernel=get_kernel(neighborhood_size),
        boundary_condition=boundary_condition,
        nominal_ignition_heat=nominal_ignition_heat,
        burn_duration=burn_duration,
        sigmoid_coefficient=sigmoid_coefficient,
    )


def _wind_factor_kernel(wind_local,
                        alpha):
  """Compute wind weight factor.

  Computes the wind kernel for convolution using local wind tensor and
  alpha_wind. For each cell and each kernel offset `r`, the angle `theta`
  between the local wind `U` and `r` is fed into an elliptical spread model:
  `z = 1 + 0.25 * |U|` is the axis ratio, `e = sqrt(1 - z**-2)` the
  eccentricity, and `a`, `b = a / z` the semi-axes. The resulting ratio is
  scaled by `alpha` and offset by 1.

  Args:
    wind_local: 5D local wind factor of size (field_length x field_width x
      nchannels x kernel_size x kernel_size).
    alpha: Sensitivity of fire to wind.

  Returns:
    4D Wind Kernel of size (field_length x field_width x kernel_size x
    kernel_size) used for convolutions.

  Raises:
    ValueError: If input specification is not met, i.e., `wind_local.ndim != 5`
      or `wind_local.shape[3] != wind_local.shape[4]`.
  """
  if wind_local.ndim != 5:
    raise ValueError('`wind_local` is not 5 dimensional.')
  if wind_local.shape[3] != wind_local.shape[4]:
    raise ValueError(
        'Dimension 3 and Dimension 4 of `wind_local` should be of equal size.')

  # Reverse the component (channel) order. NOTE(review): presumably this aligns
  # the wind components with the component order of `utils.radius_tensor`;
  # confirm.
  wind_local = jnp.flip(wind_local, axis=2)

  # Normalized r vectors
  r = utils.radius_tensor(wind_local.shape[3])
  r_norm = utils.normalize(r)

  # Wind magnitude and normalized U vectors
  mag_u = jnp.linalg.norm(wind_local, axis=2)
  u_norm = utils.normalize(wind_local)

  # Angle between U and r. Rounding can push the dot product of two unit
  # vectors slightly outside [-1, 1], where arccos returns NaN; that NaN would
  # then be zeroed by `nan_to_num` below, silently dropping the wind effect
  # exactly when wind and offset are aligned. Clamp first.
  cos_between = jnp.einsum('ijklm, klm -> ijlm', u_norm, r_norm)
  theta = jnp.arccos(jnp.clip(cos_between, -1.0, 1.0))
  cos_t = jnp.cos(theta)
  cos2 = cos_t**2
  sin2 = jnp.sin(theta)**2

  # Wind factor
  z = 1 + 0.25 * mag_u
  e = jnp.sqrt(1 - (z**-2))
  a = mag_u / (1 + e)
  b = a / z
  gamma = jnp.sqrt((a**2) * sin2 - (a**2) * (e**2) * sin2 + (b**2) * cos2)
  num = a * (b**2) * e * cos_t + a * b * gamma
  den = (a**2) * sin2 + (b**2) * cos2

  frac = num / (den + utils.EPS)
  # NaNs become 0 and infinities are clamped to +/- utils.INF.
  frac = jnp.nan_to_num(frac, nan=0, posinf=utils.INF, neginf=-utils.INF)

  psi_local = 1 + alpha * frac

  return psi_local


def get_wind_factor(field, alpha,
                    kernel_shape,
                    boundary_condition):
  """Evaluate `_wind_factor_kernel` on every neighborhood of `field`.

  Args:
    field: 3D wind field (rows x cols x channels); the channel axis holds the
      wind vector components.
    alpha: Sensitivity of fire to wind.
    kernel_shape: Neighborhood kernel shape; its first two entries give the
      patch height and width.
    boundary_condition: Boundary condition used to pad `field` before patches
      are extracted.

  Returns:
    4D wind weights (rows x cols x kernel_rows x kernel_cols).
  """
  patch_rows, patch_cols = kernel_shape[0], kernel_shape[1]
  window = (patch_rows, patch_cols)

  # The reshape below relies on `pad_tensor_3d` padding so that the unpadded
  # patch extraction yields exactly one patch per original cell.
  padded = utils.pad_tensor_3d(field, kernel_shape, boundary_condition)
  batched = jnp.expand_dims(padded, 0)
  flat = lax.conv_general_dilated_patches(
      batched,
      window,
      (1, 1),
      ((0, 0), (0, 0)),
      dimension_numbers=('NHWC', 'OIHW', 'NHWC'),
  )
  # Unflatten into the 5D (rows, cols, channels, kh, kw) layout expected by
  # `_wind_factor_kernel`.
  neighborhoods = flat.reshape(tuple(field.shape) + window)
  return _wind_factor_kernel(neighborhoods, alpha)


def _slope_factor_kernel(slope_local,
                         alpha):
  """Turn local slope patches into exponential slope weights.

  Each kernel offset direction `r` is normalized and dotted with the
  (channel-reversed) local slope; the weight is `exp(alpha * projection)`.

  Args:
    slope_local: 5D local slope factor of size (field_length x field_width x
      nchannels x kernel_size x kernel_size).
    alpha: Sensitivity of fire to slope.

  Returns:
    4D Slope Kernel (field_length x field_width x kernel_size x kernel_size)
    used for convolutions.

  Raises:
    ValueError: If `slope_local.ndim != 5` or
      `slope_local.shape[3] != slope_local.shape[4]`.
  """
  # Rank is checked first: the square-kernel test indexes dimension 4.
  if slope_local.ndim != 5:
    raise ValueError('`slope_local` is not 5 dimensional.')
  kernel_rows, kernel_cols = slope_local.shape[3], slope_local.shape[4]
  if kernel_rows != kernel_cols:
    raise ValueError(
        'Dimension 3 and Dimension 4 of `slope_local` should be of equal size.')

  reversed_slope = jnp.flip(slope_local, axis=2)
  unit_offsets = utils.normalize(utils.radius_tensor(reversed_slope.shape[3]))

  # Component of the slope along each offset direction.
  projection = jnp.einsum('klm, ijklm -> ijlm', unit_offsets, reversed_slope)
  return jnp.exp(alpha * projection)


def get_slope_factor(
    field, alpha, kernel_shape,
    boundary_condition):
  """Evaluate `_slope_factor_kernel` on every neighborhood of `field`.

  Args:
    field: 3D slope field (rows x cols x channels).
    alpha: Sensitivity of fire to slope.
    kernel_shape: Neighborhood kernel shape; its first two entries give the
      patch height and width.
    boundary_condition: Boundary condition used to pad `field`.

  Returns:
    4D slope weights (rows x cols x kernel_rows x kernel_cols).
  """
  patch_hw = (kernel_shape[0], kernel_shape[1])

  batched = jnp.expand_dims(
      utils.pad_tensor_3d(field, kernel_shape, boundary_condition), 0)
  neighborhoods = lax.conv_general_dilated_patches(
      batched, patch_hw, (1, 1), ((0, 0), (0, 0)),
      dimension_numbers=('NHWC', 'OIHW', 'NHWC'))

  # Unflatten to (rows, cols, channels, kh, kw) for the slope kernel.
  return _slope_factor_kernel(
      neighborhoods.reshape(tuple(field.shape) + patch_hw), alpha)


def get_kernel(neighborhood_size):
  """Generate the kernel matrix for a given neighborhood size.

  The kernel is a square array centred on the origin that holds 1 at every
  offset listed by `utils.get_stencil(neighborhood_size)` and 0 elsewhere.

  Args:
    neighborhood_size: Neighborhood size forwarded to `utils.get_stencil`.

  Returns:
    A `(2 * m + 1, 2 * m + 1)` float array, where `m` is the largest absolute
    offset in the stencil.
  """
  # Stencil of offsets; column 0 indexes kernel rows and column 1 kernel cols.
  stencil = utils.get_stencil(neighborhood_size)

  # Size the kernel by the largest *absolute* offset. Using the signed maximum
  # would let a negative offset larger in magnitude than every positive one
  # index below 0, where `.at[...]` silently wraps or drops it. A static
  # Python int also keeps the kernel shape concrete.
  max_index = int(jnp.amax(jnp.abs(stencil)))
  kernel_size = 2 * max_index + 1

  kernel = jnp.zeros((kernel_size, kernel_size))

  # Shift offsets so the origin lands at the kernel centre, then mark them.
  kernel = kernel.at[stencil[:, 0] + max_index,
                     stencil[:, 1] + max_index].set(1)

  return kernel


def generate_dynamic_kernel(kernel, wind_weight,
                            slope_weight):
  """Compute dynamic kernel given kernel and 4D weight arrays.

  Args:
    kernel: 2D base kernel (kernel_size x kernel_size), e.g. `base_kernel`
      from `get_kernel`.
    wind_weight: 4D per-cell wind weights (field_length x field_width x
      kernel_size x kernel_size).
    slope_weight: 4D per-cell slope weights, same shape as `wind_weight`.

  Returns:
    4D dynamic kernel `kernel * wind_weight * slope_weight`, with `kernel`
    broadcast over the two leading (per-cell) axes.

  Raises:
    ValueError: If `wind_weight` and `slope_weight` differ in shape.
  """
  # An explicit check rather than `assert`, which `python -O` strips; without
  # it, mismatched weight shapes could broadcast silently.
  if wind_weight.shape != slope_weight.shape:
    raise ValueError(
        '`wind_weight` and `slope_weight` must have the same shape, got '
        f'{wind_weight.shape} and {slope_weight.shape}.')

  dynamic_kernel = jnp.multiply(
      jnp.multiply(jnp.expand_dims(kernel, axis=(0, 1)), wind_weight),
      slope_weight)

  return dynamic_kernel


def _conv2d_dynamic(image, dynamic_kernel,
                    padding):
  """Convolve `image` with a kernel that varies per output position.

  Args:
    image: 4D NHWC image with a single channel.
    dynamic_kernel: 4D kernel (out_rows x out_cols x kernel_rows x
      kernel_cols), one kernel per output position.
    padding: Padding forwarded to `lax.conv_general_dilated_patches`
      (e.g. 'VALID').

  Returns:
    Array of shape (batch, out_rows, out_cols, 1): each output is the sum of
    its image patch weighted elementwise by its own kernel.
  """
  kshape = dynamic_kernel.shape
  out_rows, out_cols = kshape[0], kshape[1]
  window = kshape[2:]

  # (batch, out_rows, out_cols, kernel_rows * kernel_cols)
  windows = lax.conv_general_dilated_patches(
      image,
      window,
      (1, 1),
      padding,
      dimension_numbers=('NHWC', 'OIHW', 'NHWC'))

  # Flatten each per-position kernel to line up with its patch.
  weights = dynamic_kernel.reshape((1, out_rows, out_cols, kshape[2] * kshape[3]))
  weighted = windows * weights
  return weighted.sum(axis=3, keepdims=True)


@jit
def fire_active(lit):
  """Determine whether the fire is still active in any field of the batch.

  A field's fire is considered over when either of the following holds:
    - `lit` contains no `True` for that field (the fire has burnt out), or
    - the fire has reached the walls, i.e., at least one boundary row or
      column of that field contains a `True`.

  Args:
    lit: 3D tensor (batch x rows x cols) containing `True`/`1` in the region
      where there is fire currently.

  Returns:
    A scalar boolean array that is `True` if the fire is active in at least
    one field of the batch.

  Raises:
    ValueError: If `lit` is not 3 dimensional.
  """
  # `ndim` is static under `jit`, so this runs once at trace time. Unlike
  # `assert`, it is not stripped by `python -O`.
  if lit.ndim != 3:
    raise ValueError('`lit` is not 3 dimensional.')

  # A field is exhausted when none of its cells is lit.
  exhausted = jnp.logical_not(jnp.any(lit, axis=(1, 2)))

  # A field is penetrated when any lit cell sits on one of its four walls.
  wall_hits = jnp.stack([
      jnp.any(lit[:, 0, :], axis=1),
      jnp.any(lit[:, -1, :], axis=1),
      jnp.any(lit[:, :, 0], axis=1),
      jnp.any(lit[:, :, -1], axis=1),
  ])
  penetrated = jnp.any(wall_hits, axis=0)

  return jnp.any(jnp.logical_not(jnp.logical_or(exhausted, penetrated)))  # pytype: disable=bad-return-type  # jnp-type


def get_ignition_heat(simprop,
                      fieldprop,
                      simparams,
                      finite = False):
  """Compute the ignition_heat over the field.

  The ignition_heat is computed as `dry_ignition_heat + moisture_ignition_heat`
  where `dry_ignition_heat = nominal_ignition_heat x density` and
  `moisture_ignition_heat = moisture_alpha x moisture`. Finally, cells with no
  vegetation (i.e., density == 0) are assigned an infinite ignition_heat so
  they can never ignite.

  Args:
    simprop: Simulator Properties (reads `nominal_ignition_heat`).
    fieldprop: Field Properties (reads `moisture` and `density`).
    simparams: Simulator Parameters (reads `moisture_alpha`). If
      `moisture_alpha` has fewer dimensions than `moisture` (e.g. one value per
      field of a batch, or a plain scalar), trailing singleton axes are
      appended so it broadcasts.
    finite: If True, cells that can't burn get `utils.INF` instead of
      `float('inf')`.

  Returns:
    The ignition_heat, with the dtype of the computed (pre-masking) heat.
  """
  # `jnp.asarray` also admits a plain Python scalar `moisture_alpha`.
  moisture_alpha = jnp.asarray(simparams.moisture_alpha)
  if moisture_alpha.ndim < fieldprop.moisture.ndim:
    moisture_alpha = jnp.expand_dims(
        moisture_alpha,
        axis=np.arange(moisture_alpha.ndim, fieldprop.moisture.ndim))

  dry_ignition_heat = simprop.nominal_ignition_heat * fieldprop.density
  ignition_heat = dry_ignition_heat + moisture_alpha * fieldprop.moisture

  # Select with `jnp.where` rather than a boolean-mask `.at[mask].set(...)`:
  # boolean indices have historically needed to be concrete under `jit`/`vmap`
  # (NonConcreteBooleanIndexError), whereas `where` is always traceable. The
  # cast keeps the dtype the in-place update would have produced.
  no_fuel_heat = utils.INF if finite else float('inf')
  return jnp.where(fieldprop.density == 0, no_fuel_heat,
                   ignition_heat).astype(ignition_heat.dtype)


def start_fire(lit_source,
               ignition_heat):
  """Starts a fire at specified positions.

  Constructs 3 arrays -- `lit`, `heat` and `fire`. `lit` specifies which
  positions in the field are currently on fire, `fire` is a zero filled array,
  and `heat` holds `ignition_heat` at the lit positions and 0 elsewhere.

  Args:
    lit_source: Locations where the fire is started (nonzero / `True` means
      lit).
    ignition_heat: Amount of heat needed for fire to start in each cell; same
      shape as `lit_source`.

  Returns:
    BurnState object (with `burnt` all zeros).
  """
  lit = lit_source
  fire = jnp.zeros_like(lit)
  # Seed lit cells with exactly their ignition heat. Deriving `heat` from
  # `ignition_heat` (not `zeros_like(lit)`) keeps a numeric dtype even when
  # `lit_source` is boolean, and the three-argument `jnp.where` is traceable
  # under `jit`, unlike the index-returning `jnp.where(lit)`.
  heat = jnp.where(lit, ignition_heat, jnp.zeros_like(ignition_heat))

  # Keywords guard against argument-order mistakes: `create` is (lit, fire, heat).
  return BurnState.create(lit=lit, fire=fire, heat=heat)


@ft.partial(jax.vmap, in_axes=(0, None, 0, 0))
def burn_step(bstate, simprop,
              dynamic_kernel,
              ignition_heat):
  """Advance the discrete fire simulation by one time step.

  Vectorized with `jax.vmap` over the leading axis of `bstate`,
  `dynamic_kernel` and `ignition_heat`; `simprop` is shared by all fields.

  Args:
    bstate: Current BurnState of the field.
    simprop: Simulator Properties (not batched).
    dynamic_kernel: Convolution kernel for fire percolation (5D across the
      batch).
    ignition_heat: Heat needed to ignite each cell (3D across the batch).

  Returns:
    The BurnState after one step.
  """
  # Lit cells act as a one-channel float image, padded per the boundary
  # condition; the VALID convolution then trims that padding back off.
  lit_image = bstate.lit.astype(jnp.float32)[:, :, None]
  padded_image = utils.pad_tensor_3d(lit_image, simprop.base_kernel.shape,
                                     simprop.boundary_condition)
  incoming = _conv2d_dynamic(
      jnp.expand_dims(padded_image, 0), dynamic_kernel, padding='VALID')
  heat = bstate.heat + incoming[0, :, :, 0]

  # Each lit cell spends one more step on fire; once that count reaches
  # `burn_duration` the cell is burnt out.
  fire = jnp.where(bstate.lit, bstate.fire + 1, bstate.fire)
  burnt = fire >= simprop.burn_duration

  # A cell is lit when it has enough heat and is not burnt out.
  lit = (heat >= ignition_heat) & ~burnt

  return bstate.replace(lit=lit, fire=fire, heat=heat, burnt=burnt)


@ft.partial(jax.vmap, in_axes=(None, 0, 0))
def parameterized_generate_dynamic_kernel(
    simprop, fieldprop,
    simparams):
  """Build the per-cell fire-propagation kernel for each field in a batch.

  Vectorized over the leading axis of `fieldprop` and `simparams`; `simprop`
  is shared. Combines `simprop.base_kernel` with wind and slope weights.

  Args:
    simprop: Simulator Properties (not batched).
    fieldprop: Field Properties.
    simparams: Simulator Parameters.

  Returns:
    5D dynamic kernel of size `batch x field_shape... x kernel_shape...`
  """
  kernel_shape = simprop.base_kernel.shape
  boundary = simprop.boundary_condition

  wind_factor = get_wind_factor(fieldprop.wind, simparams.wind_alpha,
                                kernel_shape, boundary)

  # Slope is the first-order gradient of the terrain.
  terrain_slope = utils.gradient_o1(fieldprop.terrain, 1)
  slope_factor = get_slope_factor(terrain_slope, simparams.slope_alpha,
                                  kernel_shape, boundary)

  return generate_dynamic_kernel(simprop.base_kernel, wind_factor,
                                 slope_factor)


@ft.partial(jax.vmap, in_axes=(0, None, 0, 0))
def approximate_burn_step(bstate,
                          simprop,
                          dynamic_kernel,
                          ignition_heat):
  """Differentiable approximation of `burn_step`.

  Follows `burn_step`, but the hard comparisons are replaced by
  `utils.sigmoid(..., simprop.sigmoid_coefficient)` and the fire counter grows
  by the (soft) `lit` value instead of by 1.
  """
  steepness = simprop.sigmoid_coefficient

  # Soft fire counter and soft "burn duration reached" indicator.
  fire = bstate.fire + bstate.lit
  burnt = utils.sigmoid(fire - simprop.burn_duration, steepness)

  # Spread heat from lit cells; the VALID convolution removes the boundary
  # padding again.
  padded = utils.pad_tensor_3d(bstate.lit[:, :, None],
                               simprop.base_kernel.shape,
                               simprop.boundary_condition)
  spread = _conv2d_dynamic(
      jnp.expand_dims(padded, 0), dynamic_kernel, padding='VALID')
  heat = bstate.heat + spread[0, :, :, 0]

  # NOTE(review): |ignites - burnt| acts like XOR rather than the
  # "ignites AND NOT burnt" rule of `burn_step`: a burnt cell whose `ignites`
  # is ~0 comes out as ~1. It agrees with `burn_step` only while burnt cells
  # keep heat >= ignition_heat; `ignites * (1 - burnt)` would match exactly.
  ignites = utils.sigmoid(heat - ignition_heat, steepness)
  lit = jnp.abs(ignites - burnt)

  return bstate.replace(lit=lit, fire=fire, heat=heat, burnt=burnt)
